# 真实函数的参数缺省值为 w=1.2，b=0.5
import torch
from DL.实验3.nndl.linear import Linear
from DL.实验3.nndl.optimizer_lsm import optimizer_lsm
from matplotlib import pyplot as plt  # matplotlib 是 Python 的绘图库
from nndl.create_data import create_toy_data
from DL.实验3.nndl.mean_squared_error import mean_squared_error

def linear_func(x, w=1.2, b=0.5):
    """Evaluate the linear function ``w * x + b``.

    Args:
        x: Input value. Anything that supports ``w * x`` followed by
            ``+ b`` works (e.g. a plain number; presumably also a torch
            tensor, as in the commented-out experiment code in this file).
        w: Slope of the line. Defaults to 1.2.
        b: Intercept of the line. Defaults to 0.5.

    Returns:
        The result of ``w * x + b``.
    """
    return w * x + b


# Sanity check of mean_squared_error on two hand-written samples.
# Both tensors are built as 2x1 float32 column tensors.
y_true = torch.tensor([[-0.2], [4.9]], dtype=torch.float32)
y_pred = torch.tensor([[1.3], [2.5]], dtype=torch.float32)

# NOTE(review): mean_squared_error is defined in another module; it
# presumably returns a single-element tensor, which .item() unwraps into a
# plain Python number for printing -- confirm against its implementation.
error = mean_squared_error(y_true=y_true, y_pred=y_pred).item()
print("error:", error)

# func = linear_func
# interval = (-10, 10)
# train_num = 100  # 训练样本数目
# test_num = 50  # 测试样本数目
# noise = 2
# X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise=noise, add_outlier=False)
# X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise=noise, add_outlier=False)
#
# X_train_large, y_train_large = create_toy_data(func=func, interval=interval, sample_num=5000, noise=noise,
#                                                add_outlier=False)
#
# # torch.linspace返回一个Tensor，Tensor的值为在区间start和stop上均匀间隔的num个值，输出Tensor的长度为num
# X_underlying = torch.linspace(interval[0], interval[1], train_num)
# y_underlying = linear_func(X_underlying)
#
# # 绘制数据
# plt.scatter(X_train, y_train, marker='*', facecolor="none", edgecolor='#e4007f', s=50, label="train data")
# plt.scatter(X_test, y_test, facecolor="none", edgecolor='#f19ec2', s=50, label="test data")
# plt.plot(X_underlying, y_underlying, c='#000000', label=r"underlying distribution")
# plt.legend(fontsize='x-large')  # 给图像加图例
# plt.savefig('ml-vis.pdf')  # 保存图像到PDF文件中
# plt.show()
# input_size = 1
# model = Linear(input_size)
# model = optimizer_lsm(model, X_train.reshape([-1, 1]), y_train.reshape([-1, 1]))
# print("w_pred:", model.params['w'].item(), "b_pred: ", model.params['b'].item())
# y_train_pred = model(X_train.reshape([-1, 1])).squeeze()
# train_error = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()
# print("train error: ", train_error)
# y_test_pred = model(X_test.reshape([-1, 1])).squeeze()
# test_error = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()
# print("test error: ", test_error)
